import numpy as np
import torch
import random
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re
import torch.nn as nn
import copy
import warnings
import copy

c = torch.arange(1*2*8).reshape(1, 2, 8)
print("c\n", c)
c[..., :2] = torch.arange(2*2).reshape(1, 2, 2)  # 控制单一变量，所以前边的维数要一样只有最后一维不一样
c[..., 2:4] = torch.arange(4, 8).reshape(1, 2, 2)
print('c改\n', c)
a = c[..., :2]
b = c[..., 2:4]
# 说明拼接出来的是先a，再b 再c 不是一个整体 虽然shape最后一样
d = torch.cat((a, b, c[..., 4:]), -1)
print("d\n", d)

